import os
import copy
from PIL import Image
import numpy as np

import torch
import torch.utils.data as data
from torchvision import transforms, datasets

# Default directory for downloaded datasets; a relative path, so it is
# resolved against the current working directory when used.
DATA_ROOTS = "data"

class MNIST(data.Dataset):
    """MNIST digits exposed as 3-channel (RGB) PIL images.

    Wraps ``torchvision.datasets.mnist.MNIST``. Each sample is converted from
    its raw pixel tensor to an RGB ``PIL.Image`` before ``image_transforms``
    is applied.

    Args:
        root: Directory passed to torchvision for storing/locating the MNIST
            files. It is created if it does not already exist.
        train: If True, use the training split; otherwise the test split.
        image_transforms: Optional callable applied to each RGB ``PIL.Image``.
        download: Forwarded to torchvision's ``download`` argument. Defaults
            to True, matching the previous hard-coded behaviour.
    """

    def __init__(self, root=DATA_ROOTS, train=True, image_transforms=None,
                 download=True):
        super().__init__()
        # exist_ok=True avoids the check-then-create race of isdir()+makedirs()
        # when several processes construct the dataset concurrently. A
        # non-directory file at `root` still raises FileExistsError.
        os.makedirs(root, exist_ok=True)
        self.image_transforms = image_transforms
        self.dataset = datasets.mnist.MNIST(root, train=train, download=download)

    def __getitem__(self, index):
        """Return ``(image, target)`` for sample ``index``.

        ``image`` is an RGB ``PIL.Image``, or whatever ``image_transforms``
        returns for it. ``target`` is the class label as a Python ``int``.
        """
        pixels = self.dataset.data[index]
        target = int(self.dataset.targets[index])
        # A 2-D uint8 array is inferred as mode 'L' (grayscale) by fromarray.
        # Passing mode= explicitly is deprecated in recent Pillow releases.
        # Assumes pixel data is uint8, as torchvision's MNIST stores it.
        img = Image.fromarray(pixels.numpy()).convert('RGB')
        if self.image_transforms is not None:
            img = self.image_transforms(img)
        return img, target

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.dataset)